#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Time    : 2019/7/26 下午3:48
# @Author  : fugang_le
# @Software: PyCharm


from rank_bm25 import BM25Okapi


def bm25_():
    """Demo: rank a small pre-tokenized corpus of leave-related questions with rank_bm25.

    Prints the ``BM25Okapi`` score of every document for a fixed query, then
    the two best-matching documents returned by ``get_top_n``.
    """
    # Each sentence is already segmented; tokens are separated by single spaces.
    sentences = (
        "流产 流产假 产假 需要 提供 什么 材料",
        "陪产假 产假 要 交 什么 材料",
        "丧假 需要 提供 哪些 证明 材料",
        "休 病假 要 交 什么 材料",
        "已 转 深 户 我 需要 提交 什么 材料",
        "如何 休 陪产假 产假",
        "哪些 休假 是 需要 提交 材料 的",
        "休假 材料 提交 交到 哪",
        "陪产假 产假 有 多少 天",
        "如何 何提 提交 材料",
    )
    tokenized_corpus = [sentence.split() for sentence in sentences]
    question = "陪产假 产假 需要 提供 什么 材料".split()

    ranker = BM25Okapi(tokenized_corpus)
    print("doc scores:", ranker.get_scores(question))
    print('doc top:', ranker.get_top_n(question, tokenized_corpus, n=2))


# Runs the rank_bm25 demo at import time.
bm25_()






import numpy as np
from collections import Counter


class BM25_Model:
    """Okapi BM25 ranking over a pre-tokenized corpus.

    Each document is a sequence of hashable tokens (in this file, the output of
    ``jieba.cut``). A document's score for a query is the sum, over the
    *distinct* query terms it contains, of::

        idf(t) * tf * (k1 + 1) / (tf + k1 * (1 - b + b * |D| / avgdl))
               * qf * (k2 + 1) / (qf + k2)

    where ``tf`` is the term's count in the document, ``|D|`` is the document
    length in tokens, and ``qf`` is the term's count in the query.

    Attributes:
        documents_list: The tokenized documents, as passed in.
        documents_number: Number of documents (N).
        avg_documents_len: Mean document length in tokens (0.0 for an empty corpus).
        f: Per-document term counts, ``f[i][term] -> count``.
        idf: ``idf[term] -> log((N - df + 0.5) / (df + 0.5))``.
        k1: Document term-frequency saturation parameter.
        k2: Query term-frequency saturation parameter.
        b: Document-length normalization strength (0 disables it, 1 is full).
    """

    def __init__(self, documents_list, k1=2, k2=1, b=0.5):
        self.documents_list = documents_list
        self.documents_number = len(documents_list)
        # Guard the empty corpus, which previously raised ZeroDivisionError.
        self.avg_documents_len = (
            sum(len(document) for document in documents_list) / self.documents_number
            if self.documents_number
            else 0.0
        )
        self.f = []
        self.idf = {}
        self.k1 = k1
        self.k2 = k2
        self.b = b
        self.init()

    def init(self):
        """(Re)build per-document term counts ``self.f`` and the IDF table ``self.idf``."""
        # Reset so that calling init() again does not duplicate entries in self.f.
        self.f = []
        self.idf = {}
        df = Counter()
        for document in self.documents_list:
            tf = dict(Counter(document))
            self.f.append(tf)
            df.update(tf.keys())  # each document counts once per distinct term
        for term, doc_freq in df.items():
            # NOTE: negative for terms that occur in more than half of the
            # documents -- a known property of this classic BM25 IDF variant.
            self.idf[term] = np.log((self.documents_number - doc_freq + 0.5) / (doc_freq + 0.5))

    def get_score(self, index, query):
        """Return the BM25 score of document ``index`` for the tokenized ``query``.

        Args:
            index: Position of the document in ``documents_list``.
            query: Sequence of query tokens; repeated tokens raise the term's
                query frequency ``qf`` rather than being scored again.

        Returns:
            The score (0.0 when the document shares no term with the query).
        """
        score = 0.0
        tf = self.f[index]
        if not tf:
            # An empty document matches nothing. This also avoids dividing by a
            # zero average length when every document is empty.
            return score
        # Document length in tokens, the same measure avg_documents_len averages
        # (previously the number of distinct terms was used here).
        document_len = len(self.documents_list[index])
        length_norm = self.k1 * (1 - self.b + self.b * document_len / self.avg_documents_len)
        for term, query_freq in Counter(query).items():
            term_freq = tf.get(term)
            if not term_freq:
                continue
            doc_part = term_freq * (self.k1 + 1) / (term_freq + length_norm)
            query_part = query_freq * (self.k2 + 1) / (query_freq + self.k2)
            score += self.idf[term] * doc_part * query_part
        return score

    def get_documents_score(self, query):
        """Return the BM25 scores of ``query`` against every document, in corpus order."""
        return [self.get_score(i, query) for i in range(self.documents_number)]


# Sample corpus of Chinese legal questions as raw, untokenized strings. The name
# is rebound to jieba token lists below, before BM25_Model is built.
document_list = ["行政机关强行解除行政协议造成损失，如何索取赔偿？",  # agency forcibly rescinded an administrative agreement causing losses: how to claim compensation?
                 "借钱给朋友到期不还得什么时候可以起诉？怎么起诉？",  # a friend did not repay a loan when due: when and how can I sue?
                 "我在微信上被骗了，请问被骗多少钱才可以立案？",  # scammed on WeChat: how much must be lost before a case can be filed?
                 "公民对于选举委员会对选民的资格申诉的处理决定不服，能不能去法院起诉吗？",  # disputing an election committee's voter-eligibility ruling: can one sue in court?
                 "有人走私两万元，怎么处置他？",  # someone smuggled 20,000 yuan: how will they be dealt with?
                 "法律上餐具、饮具集中消毒服务单位的责任是不是对消毒餐具、饮具进行检验？"]  # must a tableware disinfection service legally inspect the items it disinfects?

import jieba
# Tokenize the sample corpus with jieba and build the hand-written BM25 model.
document_list = [list(jieba.cut(doc)) for doc in document_list]
bm25_model = BM25_Model(document_list)
# Dump the model's internal statistics for inspection.
print(bm25_model.documents_list)
print(bm25_model.documents_number)
print(bm25_model.avg_documents_len)
print(bm25_model.f)
print(bm25_model.idf)

# Score a tokenized query against every document and report the result.
# Previously the scores were computed and then discarded; the output labels
# mirror those used by bm25_().
query = "走私了两万元，在法律上应该怎么量刑？"
query = list(jieba.cut(query))
scores = bm25_model.get_documents_score(query)
print("doc scores:", scores)
if scores:
    best_index = max(range(len(scores)), key=scores.__getitem__)
    print("doc top:", document_list[best_index])

